Coverage for cpprb/util.py: 72%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1import numpy as np 

2 

3from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict 

4 

def from_space(space, int_type, float_type):
    """
    Translate a single (non-composite) ``gym.spaces`` space into an
    ``env_dict`` entry of the form ``{"dtype": ..., "shape": ...}``.

    Parameters
    ----------
    space : gym.Space
        One of ``Discrete``, ``MultiDiscrete``, ``Box`` or ``MultiBinary``.
    int_type : np.dtype
        dtype used for ``Discrete``, ``MultiDiscrete`` and ``MultiBinary``.
    float_type : np.dtype
        dtype used for ``Box``.

    Returns
    -------
    dict
        ``{"dtype": dtype, "shape": shape}`` describing ``space``.

    Raises
    ------
    NotImplementedError
        If ``space`` is none of the supported space types.
    """
    # Checked in order; the first matching class wins.  Each entry is
    # (space class, use float dtype?, callable extracting the shape).
    # The shape getters are lambdas so attributes are only read for the
    # matching space type.
    handlers = (
        (Discrete, False, lambda sp: 1),
        (MultiDiscrete, False, lambda sp: sp.nvec.shape),
        (Box, True, lambda sp: sp.shape),
        (MultiBinary, False, lambda sp: sp.n),
    )
    for space_cls, is_float, shape_of in handlers:
        if isinstance(space, space_cls):
            return {"dtype": float_type if is_float else int_type,
                    "shape": shape_of(space)}
    raise NotImplementedError(f"Error: Unknown Space {space}")

16 

def create_env_dict(env, *, int_type=None, float_type=None):
    """
    Create ``env_dict`` from OpenAI ``gym.spaces`` for ``ReplayBuffer`` constructor

    Parameters
    ----------
    env : gym.Env
        Environment
    int_type: np.dtype, optional
        Integer type. Default is ``np.int32``
    float_type: np.dtype, optional
        Floating point type. Default is ``np.float32``

    Returns
    -------
    env_dict : dict
        ``env_dict`` parameter for ``ReplayBuffer`` class.

    Notes
    -----
    ``rew`` and ``done`` are always stored with shape 1 and ``float_type``.
    Observation / action keys follow the same naming scheme as the dict
    returned by the function built in ``create_before_add_func``:

    * ``Tuple`` space: ``obs{i}``, ``next_obs{i}``, ``act{i}``
    * ``Dict`` space: ``obs_{key}``, ``next_obs_{key}``, ``act_{key}``
    * any other space: ``obs``, ``next_obs``, ``act``
    """
    # Test against None explicitly: non-structured ``np.dtype`` instances
    # have no fields, so ``len(dtype) == 0`` and they are falsy.  The old
    # ``int_type or np.int32`` therefore silently replaced an explicit
    # ``np.dtype("int64")`` with the default.
    if int_type is None:
        int_type = np.int32
    if float_type is None:
        float_type = np.float32

    env_dict = {"rew": {"shape": 1, "dtype": float_type},
                "done": {"shape": 1, "dtype": float_type}}

    observation_space = env.observation_space
    action_space = env.action_space

    if isinstance(observation_space, Tuple):
        for i, s in enumerate(observation_space.spaces):
            env_dict[f"obs{i}"] = from_space(s, int_type, float_type)
            env_dict[f"next_obs{i}"] = from_space(s, int_type, float_type)
    elif isinstance(observation_space, Dict):
        # Prefix sub-space names.  Bare names could collide with action
        # keys or with "rew"/"done", and would not match the
        # "obs_{key}" / "next_obs_{key}" keys that
        # create_before_add_func produces for Dict spaces.
        for n, s in observation_space.spaces.items():
            env_dict[f"obs_{n}"] = from_space(s, int_type, float_type)
            env_dict[f"next_obs_{n}"] = from_space(s, int_type, float_type)
    else:
        env_dict["obs"] = from_space(observation_space, int_type, float_type)
        env_dict["next_obs"] = from_space(observation_space, int_type, float_type)

    if isinstance(action_space, Tuple):
        for i, s in enumerate(action_space.spaces):
            env_dict[f"act{i}"] = from_space(s, int_type, float_type)
    elif isinstance(action_space, Dict):
        # Same prefixing rule as observations ("act_{key}").
        for n, s in action_space.spaces.items():
            env_dict[f"act_{n}"] = from_space(s, int_type, float_type)
    else:
        env_dict["act"] = from_space(action_space, int_type, float_type)

    return env_dict

67 

def create_before_add_func(env):
    """
    Build a converter that maps one environment transition onto the keyword
    arguments expected by ``ReplayBuffer.add``.

    Parameters
    ----------
    env : gym.Env
        Environment whose ``observation_space`` and ``action_space`` decide
        how observations and actions are flattened.

    Returns
    -------
    before_add : callable
        ``before_add(obs, act, next_obs, rew, done)`` returning a ``dict``.
        A ``Tuple`` space expands to ``{name}{i}`` keys, a ``Dict`` space to
        ``{name}_{key}`` keys, and any other space to a single ``{name}`` key,
        where ``name`` is ``"obs"``, ``"act"`` or ``"next_obs"``.
        ``rew`` and ``done`` are passed through unchanged.
    """
    def select_converter(space):
        # Return a function (prefix, value) -> dict of flattened entries
        # appropriate for the given space type.
        if isinstance(space, Tuple):
            return lambda prefix, values: {f"{prefix}{idx}": item
                                           for idx, item in enumerate(values)}
        if isinstance(space, Dict):
            return lambda prefix, mapping: {f"{prefix}_{k}": item
                                            for k, item in mapping.items()}
        return lambda prefix, value: {f"{prefix}": value}

    convert_obs = select_converter(env.observation_space)
    convert_act = select_converter(env.action_space)

    def before_add(obs, act, next_obs, rew, done):
        # Merge in the same order as before: later entries override earlier
        # ones on key collision.
        transition = {}
        transition.update(convert_obs("obs", obs))
        transition.update(convert_act("act", act))
        transition.update(convert_obs("next_obs", next_obs))
        transition["rew"] = rew
        transition["done"] = done
        return transition

    return before_add